package com.aimashi.dynamicmongo.config;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.mongodb.core.MongoDatabaseFactorySupport;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.SimpleMongoClientDatabaseFactory;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
/**
 * Aspect that routes {@code MongotemplteService} calls to a per-call MongoDB database.
 *
 * <p>The first argument of every advised method is taken as the database name. It is substituted
 * for each {@code #} in {@code spring.data.mongodb.uri} to build (and cache) a {@link
 * SimpleMongoClientDatabaseFactory}. The first {@link MongoTemplate}/{@link MultiMongoTemplate}
 * field declared on the target is made to hold a {@link MultiMongoTemplate}. The factory is bound
 * to that template before the advised method runs and unbound again afterwards.
 *
 * <p>Created by AI码师 on 2019/4/19.
 */
@Component
@Aspect
public class MongoSwitch {
  private static final Logger logger = LoggerFactory.getLogger(MongoSwitch.class);

  @Autowired private MongoDatabaseFactorySupport mongoDbFactory;

  /**
   * One database factory per database name. This aspect is a singleton invoked from many threads,
   * so a concurrent map with an atomic {@code computeIfAbsent} is used. A plain HashMap with
   * check-then-put can be corrupted or build duplicate factories.
   */
  private final ConcurrentMap<String, SimpleMongoClientDatabaseFactory> factoriesByDbName =
      new ConcurrentHashMap<>();

  /** Guards the one-time replacement of a target's template field. */
  private final Object templateInstallLock = new Object();

  /** Replica-set connection URI from configuration; {@code #} marks where the database name goes. */
  @Value("${spring.data.mongodb.uri}")
  private String uri;

  /** Matches every public method of {@code MongotemplteService}. */
  @Pointcut("execution(public * com.aimashi.dynamicmongo.config.MongotemplteService.*(..))")
  public void routeMongoDB() {}

  /**
   * Binds the database named by the first argument to the target's template, runs the advised
   * method, and always unbinds the factory afterwards.
   *
   * <p>NOTE(review): exceptions thrown by the advised method are logged and swallowed, so callers
   * receive {@code null} instead of the error. That contract is preserved here; consider
   * rethrowing instead.
   *
   * @param joinPoint the advised call; its first argument must be the database name (a String)
   * @return the advised method's result, or {@code null} if switching or the call itself failed
   */
  @Around("routeMongoDB()")
  public Object routeMongoDB(ProceedingJoinPoint joinPoint) {
    String dbName = (String) joinPoint.getArgs()[0];
    Object target = joinPoint.getTarget();

    MultiMongoTemplate mongoTemplate;
    try {
      mongoTemplate = bindTemplate(target, dbName);
    } catch (Exception e) {
      // Same as before: if switching fails, the advised method is not invoked.
      logger.error("Failed to switch MongoDB database to '{}'", dbName, e);
      return null;
    }

    try {
      return joinPoint.proceed();
    } catch (Throwable t) {
      logger.error("Invocation of {} failed", joinPoint.getSignature(), t);
      return null;
    } finally {
      // Clear the thread-local factory on every path (success or failure) so it cannot leak
      // into later work on this thread. mongoTemplate is null when no template field exists.
      if (mongoTemplate != null) {
        mongoTemplate.removeMongoDbFactory();
      }
    }
  }

  /**
   * Finds the first field declared directly on {@code target} whose type is {@link MongoTemplate}
   * or {@link MultiMongoTemplate}, ensures it holds a {@link MultiMongoTemplate}, and binds the
   * factory for {@code dbName} to it.
   *
   * @param target the advised object
   * @param dbName database name substituted into the configured URI
   * @return the bound template, or {@code null} if the target declares no such field
   * @throws IllegalAccessException if the field cannot be read or written reflectively
   */
  private MultiMongoTemplate bindTemplate(Object target, String dbName)
      throws IllegalAccessException {
    for (Field field : target.getClass().getDeclaredFields()) {
      Class<?> type = field.getType();
      if (type != MongoTemplate.class && type != MultiMongoTemplate.class) {
        continue;
      }
      field.setAccessible(true);
      SimpleMongoClientDatabaseFactory factory = factoryFor(dbName);
      MultiMongoTemplate template = installTemplate(target, field, factory);
      template.setMongoDbFactory(factory);
      return template;
    }
    return null;
  }

  /** Returns the cached factory for {@code dbName}, creating it atomically on first use. */
  private SimpleMongoClientDatabaseFactory factoryFor(String dbName) {
    return factoriesByDbName.computeIfAbsent(
        dbName, name -> new SimpleMongoClientDatabaseFactory(uri.replace("#", name)));
  }

  /**
   * Returns the {@link MultiMongoTemplate} held by {@code field}, installing a new one if the field
   * currently holds anything else (a plain {@link MongoTemplate} or {@code null}).
   *
   * <p>The decision uses the field's current value, not its declared type. Deciding by declared
   * type alone would install a fresh template on every call. Concurrent calls would then overwrite
   * each other's template while it is in use.
   */
  private MultiMongoTemplate installTemplate(
      Object target, Field field, SimpleMongoClientDatabaseFactory factory)
      throws IllegalAccessException {
    synchronized (templateInstallLock) {
      Object current = field.get(target);
      if (current instanceof MultiMongoTemplate) {
        return (MultiMongoTemplate) current;
      }
      MultiMongoTemplate created = new MultiMongoTemplate(factory);
      field.set(target, created);
      return created;
    }
  }
}
